# -*- coding: utf-8 -*-
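'''
Regression-tree (CART) training and prediction script.

Usage:
test(trainfile, m, n, splittype): train on ``trainfile`` and print the
correlation between the predictions and the actual targets of the training
data; the closer the printed value is to 1, the better. ``m`` is the tolerated
(minimum) error decrease for a split and ``n`` is the minimum number of
samples allowed on each side of a split.

predata(trainfile, prefile, outfile, m, n, splittype): train on ``trainfile``,
predict every row of ``prefile`` and append the results to ``outfile``.
'''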
from numpy import *
import sys

def loadDataSet(fileName,splittype):
    """Read a delimited text file of numbers into a list of float rows.

    Args:
        fileName: path of the file to read.
        splittype: field separator passed to ``str.split``.

    Returns:
        A list with one ``list`` of floats per non-blank line.

    Raises:
        ValueError: if a field cannot be converted with ``float``.
    """
    dataMat = []
    # ``with`` guarantees the handle is closed even if a field fails to parse.
    with open(fileName) as fr:
        for line in fr:
            record = line.strip()
            if not record:
                # A blank (e.g. trailing) line carries no sample.
                continue
            # Build a real list: on Python 3 ``map`` would leave a lazy
            # iterator in every row instead of the numbers themselves.
            dataMat.append([float(field) for field in record.split(splittype)])
    return dataMat

def binSplitDataSet(dataSet, feature, value):
    """Split the rows of ``dataSet`` in two by thresholding one column.

    Args:
        dataSet: 2-D numpy matrix/array with one sample per row.
        feature: index of the column to compare.
        value: threshold.

    Returns:
        ``(mat0, mat1)`` where ``mat0`` holds every row whose ``feature``
        column is ``> value`` and ``mat1`` every row where it is ``<= value``.
        Either part may have zero rows.
    """
    column = dataSet[:, feature]
    # nonzero(...)[0] is the array of matching row indices.  The selection must
    # not be followed by a further ``[0]``: on current NumPy that would keep
    # only the first matching row (and fail outright when nothing matches).
    mat0 = dataSet[nonzero(column > value)[0], :]
    mat1 = dataSet[nonzero(column <= value)[0], :]
    return mat0,mat1

def regLeaf(dataSet):
    """Return the leaf value for ``dataSet``: the mean of its last column."""
    targets = dataSet[:, -1]
    return mean(targets)

def regErr(dataSet):
    """Return the total squared error of the last column of ``dataSet``.

    Computed as the variance of the last column multiplied by the number of
    rows.
    """
    targets = dataSet[:, -1]
    rowCount = shape(dataSet)[0]
    return var(targets) * rowCount

def linearSolve(dataSet):
    """Fit an ordinary-least-squares linear model to ``dataSet``.

    The last column of ``dataSet`` is the target; all preceding columns are
    the features.  A column of ones is prepended to the features so that the
    first weight is the intercept.

    Args:
        dataSet: numpy matrix with one sample per row.

    Returns:
        ``(ws, X, Y)``: the weight column vector, the design matrix (with the
        leading column of ones) and the target column.

    Raises:
        NameError: if ``X.T * X`` has a determinant of exactly zero.  The
            exception type is kept as-is for existing callers.
    """
    m,n = shape(dataSet)
    # ``asmatrix`` instead of the ``mat`` alias, which NumPy 2.0 removed.
    X = asmatrix(ones((m,n)))
    # Column 0 stays all ones (intercept); the rest are the feature columns.
    X[:,1:n] = dataSet[:,0:n-1]
    Y = dataSet[:,-1]
    xTx = X.T*X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n'
                        '        try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws,X,Y

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Find the feature/value pair that best splits ``dataSet``.

    Args:
        dataSet: numpy matrix, one sample per row, target in the last column.
        leafType: callable building a leaf value from a data set.
        errType: callable returning the error of a data set.
        ops: ``(tolS, tolN)`` -- the minimum error reduction a split must
            achieve and the minimum number of rows allowed on each side.

    Returns:
        ``(featureIndex, splitValue)`` for the best split, or
        ``(None, leafType(dataSet))`` when no split should be made.
    """
    tolS, tolN = ops[0], ops[1]
    # Every target is identical: nothing can be gained by splitting.
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    n = shape(dataSet)[1]
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        # Turn the column into plain Python floats first: iterating a matrix
        # column yields 1x1 matrices, which are unhashable and cannot be put
        # in a set.
        for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # Not enough improvement (bestS is still inf when no candidate qualified).
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    # Defensive re-check of the size constraint on the chosen split.
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex,bestValue

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Recursively build a tree from ``dataSet``.

    Args:
        dataSet: numpy matrix, one sample per row, target in the last column.
        leafType: callable building a leaf value from a data set.
        errType: callable returning the error of a data set.
        ops: ``(tolS, tolN)`` forwarded to ``chooseBestSplit``.

    Returns:
        Either a leaf value (when ``chooseBestSplit`` reports no split) or a
        dict with keys ``'spInd'`` (feature index), ``'spVal'`` (split value),
        ``'left'`` and ``'right'`` (sub-trees built from the two parts
        returned by ``binSplitDataSet``, in that order).
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # ``feat`` may legitimately be 0, so test identity with None explicitly.
    if feat is None:
        return val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }

def isTree(obj):
    """Return True when ``obj`` is an internal tree node (a dict).

    ``isinstance`` is used rather than comparing the type's name so that dict
    subclasses are recognised as nodes too.
    """
    return isinstance(obj, dict)
    
def regTreeEval(model, inDat):
    """Evaluate a regression-tree leaf: the leaf value itself, as a float.

    ``inDat`` is accepted for signature compatibility and is not used.
    """
    leafValue = float(model)
    return leafValue

def treeForeCast(tree, inData, modelEval=regTreeEval):
    """Predict the target of one sample by walking ``tree``.

    Args:
        tree: a leaf value, or a node dict with keys ``'spInd'``, ``'spVal'``,
            ``'left'`` and ``'right'``.
        inData: the feature values of a single sample -- a 1xn matrix, a 1-D
            array or a plain sequence.
        modelEval: callable ``(leaf, inData)`` producing the prediction at a
            leaf.

    Returns:
        Whatever ``modelEval`` returns for the leaf that is reached.
    """
    if not isTree(tree):
        return modelEval(tree, inData)
    # Flatten before indexing: plain ``inData[k]`` on a 1xn matrix selects
    # row k (not feature k), which fails for every multi-feature sample.
    featVal = asarray(inData).ravel()[tree['spInd']]
    # Values above the split value go left, all others go right.
    branch = tree['left'] if featVal > tree['spVal'] else tree['right']
    # The recursive call evaluates the branch directly when it is a leaf.
    return treeForeCast(branch, inData, modelEval)
        
def createForeCast(tree, testData, modelEval=regTreeEval):
    """Predict the target of every row of ``testData``.

    Args:
        tree: tree (or leaf value) to evaluate.
        testData: sequence/matrix with one sample per row.
        modelEval: leaf evaluator forwarded to ``treeForeCast``.

    Returns:
        An ``m x 1`` numpy matrix holding one prediction per row.
    """
    m = len(testData)
    # ``asmatrix`` instead of the ``mat`` alias, which NumPy 2.0 removed.
    yHat = asmatrix(zeros((m,1)))
    for i, sample in enumerate(testData):
        yHat[i,0] = treeForeCast(tree, asmatrix(sample), modelEval)
    return yHat

def test(trainfile,m,n,splittype):
    """Train a regression tree on ``trainfile`` and report its fit.

    Prints the correlation coefficient between the tree's predictions for the
    training rows and their actual targets.

    Args:
        trainfile: path of the training file (target in the last column).
        m: minimum error reduction required for a split.
        n: minimum number of samples on each side of a split.
        splittype: field separator of ``trainfile``.
    """
    # Output goes through sys.stdout.write / single-argument print() so the
    # function parses and behaves the same on Python 2 and Python 3.
    sys.stdout.write("running... ")
    sys.stdout.flush()  # show the progress text before training starts
    datamat = asmatrix(loadDataSet(trainfile,splittype))
    mytree = createTree(datamat,ops=(m,n))
    yhat = createForeCast(mytree,datamat[:,0:-1])
    print("越接近1正确率越高")
    print(corrcoef(yhat,datamat[:,-1],rowvar=0)[0,1])
    print("end...")

def predata(trainfile,prefile,outfile,m,n,splittype):
    """Train on ``trainfile`` and predict every row of ``prefile``.

    Each non-blank input row is appended to ``outfile`` followed by
    ``splittype`` and the predicted value.

    Args:
        trainfile: path of the training file (target in the last column).
        prefile: path of the file holding the feature rows to predict.
        outfile: path of the result file; it is opened in append mode.
        m: minimum error reduction required for a split.
        n: minimum number of samples on each side of a split.
        splittype: field separator used by all three files.

    Raises:
        ValueError: if a field of ``prefile`` cannot be converted to float.
    """
    # Output goes through sys.stdout.write / single-argument print() so the
    # function parses and behaves the same on Python 2 and Python 3.
    sys.stdout.write('running ')
    sys.stdout.flush()
    datamat = asmatrix(loadDataSet(trainfile,splittype))
    mytree = createTree(datamat,ops=(m,n))
    # ``with`` closes both files even when a row fails to parse.
    with open(prefile) as fr:
        with open(outfile,'a') as fw:
            for line in fr:
                record = line.strip()
                if not record:
                    # Blank lines hold no sample to predict.
                    continue
                features = [float(field) for field in record.split(splittype)]
                # One progress marker per field read.
                sys.stdout.write('> ' * len(features))
                prediction = treeForeCast(mytree,asmatrix(features),regTreeEval)
                fw.write(record + splittype + str(prediction) +'\n')
    print("end...")

if __name__ == "__main__":
    # Command line:
    #   <script> trainfile m n splittype                  -> test()
    #   <script> trainfile prefile outfile m n splittype  -> predata()
    argc = len(sys.argv)
    if argc in (5, 7):
        # A tab separator typed on the command line arrives as the two
        # characters backslash + "t"; translate it to a real tab.
        sep = "\t" if sys.argv[-1] == "\\t" else sys.argv[-1]
        if argc == 5:
            test(sys.argv[1], int(sys.argv[2]), int(sys.argv[3]), sep)
        else:
            predata(sys.argv[1], sys.argv[2], sys.argv[3],
                    int(sys.argv[4]), int(sys.argv[5]), sep)
    else:
        # Report wrong usage instead of silently doing nothing.
        sys.exit("usage: %s trainfile m n splittype\n"
                 "       %s trainfile prefile outfile m n splittype"
                 % (sys.argv[0], sys.argv[0]))
